
Demote(B)Float16 pass: only keep enabled for PPC. #55486

Merged: 1 commit into master, Aug 17, 2024
Conversation

@maleadt (Member) commented Aug 13, 2024

LLVM should handle this properly now for everything but PPC (where BFloat16 isn't supported anyway).

I considered stripping the bf16 bits from the pass, but went for the more conservative change for now in case we discover issues lurking in targets that aren't covered by CI.

Fixes #55479

@maleadt added the compiler:codegen (Generation of LLVM IR and native code) label Aug 13, 2024
@maleadt marked this pull request as draft August 13, 2024 19:16
@gbaraldi (Member) commented Aug 13, 2024

So the backends are now doing the right thing by default? Except PPC?

@oscardssmith (Member) commented

@gbaraldi Yes, see #55479 (comment)

@maleadt force-pushed the tb/nerf_demote_pass branch from 230e4a8 to 3c4cc92, August 14, 2024 06:13
@maleadt marked this pull request as ready for review August 14, 2024 08:00
@maleadt (Member, Author) commented Aug 14, 2024

@giordano Can you check this now does the right thing on Grace?

@giordano (Contributor) commented

On this PR I get

julia> code_llvm(NTuple{2,BFloat16}; debuginfo=:none) do a, b
           sqrt(a * a + b * b)
       end
; Function Signature: var"#7"(Core.BFloat16, Core.BFloat16)
define bfloat @"julia_#7_3558"(bfloat %"a::BFloat16", bfloat %"b::BFloat16") #0 {
top:
  %bitcast_coercion = bitcast bfloat %"a::BFloat16" to i16
  %0 = zext i16 %bitcast_coercion to i32
  %1 = shl nuw i32 %0, 16
  %bitcast_coercion2 = bitcast i32 %1 to float
  %2 = fmul float %bitcast_coercion2, %bitcast_coercion2
  %3 = fcmp ord float %2, 0.000000e+00
  br i1 %3, label %L13, label %L32

L13:                                              ; preds = %top
  %bitcast_coercion104 = bitcast float %2 to i32
  %4 = lshr i32 %bitcast_coercion104, 16
  %5 = and i32 %4, 1
  %narrow = add i32 %bitcast_coercion104, 32767
  %6 = add i32 %narrow, %5
  %7 = and i32 %6, -65536
  %8 = bitcast i32 %7 to float
  br label %L32

L32:                                              ; preds = %L13, %top
  %bitcast_coercion28 = phi float [ %8, %L13 ], [ 0x7FF8000000000000, %top ]
  %bitcast_coercion11 = bitcast bfloat %"b::BFloat16" to i16
  %9 = zext i16 %bitcast_coercion11 to i32
  %10 = shl nuw i32 %9, 16
  %bitcast_coercion16 = bitcast i32 %10 to float
  %11 = fmul float %bitcast_coercion16, %bitcast_coercion16
  %12 = fcmp ord float %11, 0.000000e+00
  br i1 %12, label %L44, label %L63

L44:                                              ; preds = %L32
  %bitcast_coercion84 = bitcast float %11 to i32
  %13 = lshr i32 %bitcast_coercion84, 16
  %14 = and i32 %13, 1
  %narrow126 = add i32 %bitcast_coercion84, 32767
  %15 = add i32 %narrow126, %14
  %16 = and i32 %15, -65536
  %17 = bitcast i32 %16 to float
  br label %L63

L63:                                              ; preds = %L44, %L32
  %bitcast_coercion35 = phi float [ %17, %L44 ], [ 0x7FF8000000000000, %L32 ]
  %18 = fadd float %bitcast_coercion28, %bitcast_coercion35
  %19 = fcmp ord float %18, 0.000000e+00
  br i1 %19, label %L94, label %L102

L94:                                              ; preds = %L63
  %bitcast_coercion64 = bitcast float %18 to i32
  %20 = lshr i32 %bitcast_coercion64, 16
  %21 = and i32 %20, 1
  %narrow127 = add i32 %bitcast_coercion64, 32767
  %22 = add i32 %narrow127, %21
  %23 = and i32 %22, -65536
  %24 = bitcast i32 %23 to float
  %25 = fcmp uge float %24, 0.000000e+00
  br i1 %25, label %L102, label %L100

L100:                                             ; preds = %L94
  call void @j_throw_complex_domainerror_3570(ptr nonnull @"jl_sym#sqrt#3571.jit", float %24) #10
  unreachable

L102:                                             ; preds = %L94, %L63
  %bitcast_coercion44130 = phi float [ %24, %L94 ], [ 0x7FF8000000000000, %L63 ]
  %26 = call float @llvm.sqrt.f32(float %bitcast_coercion44130)
  %27 = fcmp ord float %26, 0.000000e+00
  br i1 %27, label %L107, label %L126

L107:                                             ; preds = %L102
  %bitcast_coercion53 = bitcast float %26 to i32
  %28 = lshr i32 %bitcast_coercion53, 16
  %29 = and i32 %28, 1
  %narrow128 = add nuw nsw i32 %29, 32767
  %30 = zext nneg i32 %narrow128 to i64
  %31 = zext i32 %bitcast_coercion53 to i64
  %32 = add nuw nsw i64 %30, %31
  %33 = lshr i64 %32, 16
  %34 = trunc i64 %33 to i16
  %bitcast_coercion62 = bitcast i16 %34 to bfloat
  br label %L126

L126:                                             ; preds = %L107, %L102
  %value_phi51 = phi bfloat [ %bitcast_coercion62, %L107 ], [ 0xR7FC0, %L102 ]
  ret bfloat %value_phi51
}

which is exactly the same IR I get on nightly (modulo the gensymed function names). It's not doing native operations in bfloat, but as said in #55417 (comment), bf16 support on aarch64 doesn't appear to do much.

@maleadt (Member, Author) commented Aug 14, 2024

Ah, that's because BFloat16s needs to be updated: https://github.com/JuliaMath/BFloat16s.jl/blob/2266cc578d973bbd27fde7fc25a3d6dea0160f80/src/bfloat16.jl#L20-L49

And FWIW, we need to be careful about extending that check, because e.g. AArch64 doesn't even support bf16 arithmetic on LLVM 18, only on 19: https://godbolt.org/z/vPhbYWrT4

@maleadt (Member, Author) commented Aug 14, 2024

With a Julia build against LLVM 19 + JuliaMath/BFloat16s.jl#77 on macos-aarch64 (m3):

julia> code_llvm(NTuple{2,BFloat16}; debuginfo=:none) do a, b
           sqrt(a * a + b * b)
       end
define bfloat @"julia_#1_2245"(bfloat %"a::BFloat16", bfloat %"b::BFloat16") #0 {
L9:
  %0 = fmul bfloat %"a::BFloat16", %"a::BFloat16"
  %1 = fmul bfloat %"b::BFloat16", %"b::BFloat16"
  %2 = fadd bfloat %0, %1
  %3 = fpext bfloat %2 to float
  %4 = call float @llvm.sqrt.f32(float %3)
  %5 = fptrunc float %4 to bfloat
  ret bfloat %5
}

The generated code of course still contains conversions, as you mentioned there are no scalar bf16 arithmetic instructions yet:

julia> code_native(NTuple{2,BFloat16}; debuginfo=:none) do a, b
           sqrt(a * a + b * b)
       end
	fmov	w8, s0
	lsl	w8, w8, #16
	fmov	s0, w8
	fmul	s0, s0, s0
	bfcvt	h0, s0
	fmov	w8, s1
	lsl	w8, w8, #16
	fmov	s1, w8
	fmul	s1, s1, s1
	bfcvt	h1, s1
	fmov	w8, s1
	lsl	w8, w8, #16
	fmov	s1, w8
	fmov	w8, s0
	lsl	w8, w8, #16
	fmov	s0, w8
	fadd	s0, s0, s1
	bfcvt	h0, s0
	fmov	w8, s0
	lsl	w8, w8, #16
	fmov	s0, w8
	fsqrt	s0, s0
	bfcvt	h0, s0
	ret

@maleadt merged commit faa6095 into master Aug 17, 2024 (9 checks passed)
@maleadt deleted the tb/nerf_demote_pass branch August 17, 2024 15:00
KristofferC pushed a commit that referenced this pull request Sep 12, 2024:

LLVM should handle this properly now for everything but PPC (where BFloat16 isn't supported anyway).
Labels: compiler:codegen (Generation of LLVM IR and native code), float16
Successfully merging this pull request may close these issues: Consider dropping bfloat16 demotion pass